import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_loader import *

import numpy as np
from Utils import *
import time
from torch.optim import lr_scheduler
from torchvision import models
import json

from MoE_RIM import MoE_RIM
import miniImagenetOneShot
import tieredImagenetOneShot

# Import `random` explicitly so the seeding below does not depend on one of
# the wildcard imports above happening to re-export it.
import random

# Single seed shared by every random number generator used in this script.
randomseed = 1

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
    # Make the selected GPU the current CUDA device for this process.
    torch.cuda.set_device(device)

# Seed the Python, NumPy and PyTorch generators with the same value.
random.seed(randomseed)
np.random.seed(randomseed)
torch.manual_seed(randomseed)

# Experiment configuration shared by the whole script. The dict is passed to
# the model, to train_model() and to evaluate(), and its repr is written to
# the results file, so the key order below is kept stable.
args = {
    # --- model / data selection -------------------------------------------
    # main() only builds a model when this is 'BDC'.
    'model': 'BDC',
    # 'miniImagenet' or 'tieredImagenet' (the tieredImagenet setting is also
    # available for the Dermnet dataset).
    'dataset': 'miniImagenet',
    # Path to the dataset root; change it to your local copy.
    'root': './mini-imagenet/',
    # One of 'NN', 'RF', 'SVM', 'LogisticRegression'.
    'classifier': 'LogisticRegression',

    # --- episode layout ---------------------------------------------------
    'n_shot': 5,    # passed to the datasets as samples_per_class
    'n_way': 5,     # passed to the datasets as classes_per_set
    'n_query': 5,   # passed to the datasets as samples_per_query

    # --- model hyper-parameters (not read in main(); presumably consumed
    # by MoE_RIM -- TODO confirm) ------------------------------------------
    'n_gaus': 100,
    'edge_dim': 128,
    'g_dim': 256,
    # DataLoader batch size; also multiplies the episode counts in main().
    'batchSize': 8,
    'overlap': 0,

    # --- optimisation -----------------------------------------------------
    'lr': 0.0001,           # forwarded to train_model() as lr
    'seed': randomseed,     # main() re-seeds NumPy with this value
    'optimizer': 'Adam',
    'lr_decay': 0.9,
    'weight_decay': 0.0003,
    # Multiplied by batchSize to size the train and test episode sets.
    'epoch': 1000,
    # Forwarded to train_model() as epochs.
    'Train_epoch': 1,
    'lambda1': 0.0005,
    'lambda2': 0.0004,
    # Directory that receives results.txt.
    'out_dir': "./outfilesmain/",
    'TV_lr': 0.0002,
    # Multiplied by batchSize to size the validation episode set.
    'valruns': 200,

    # --- mixture-of-experts setup -----------------------------------------
    'n_base': 10,    # number of base classes
    'experts': 3,    # number of MoE experts
    'device': device,
    'omega': 0.5,
    # NOTE(review): main() hard-codes num_workers=4 for its DataLoaders and
    # never reads this key.
    'workers': 1,
    'ImageSize': 224,
}

def main():
    """Run one few-shot experiment end to end, driven by the module-level ``args``.

    Steps:
      1. Validate ``args['dataset']`` / ``args['model']`` and make sure the
         output directory exists, so configuration mistakes surface before
         any data loading or training happens.
      2. Re-seed NumPy from ``args['seed']``.
      3. Load the base-class images and build the train / val / test episode
         datasets for the selected dataset.
      4. Wrap the datasets in DataLoaders, build the model, train it with
         ``train_model`` and score it on the test loader with ``evaluate``.
      5. Append the configuration and the resulting accuracy to
         ``results.txt`` inside ``args['out_dir']``.

    Raises:
        ValueError: if ``args['dataset']`` or ``args['model']`` is not a
            supported value.
    """
    # Function-scope import: this file does not import `os` at module level.
    import os

    dataset_name = args['dataset']
    if dataset_name not in ('tieredImagenet', 'miniImagenet'):
        raise ValueError(
            f"Unsupported dataset {dataset_name!r}; "
            "expected 'tieredImagenet' or 'miniImagenet'."
        )
    if args['model'] != 'BDC':
        # Previously an unknown model left `model` unbound and failed later
        # with an UnboundLocalError.
        raise ValueError(f"Unsupported model {args['model']!r}; expected 'BDC'.")

    # Create the results directory up front so a bad path cannot discard a
    # finished training run at the very last step.
    out_dir = args['out_dir']
    os.makedirs(out_dir, exist_ok=True)

    #set random seed
    np.random.seed(args['seed'])

    if dataset_name == 'tieredImagenet':
        dataset_cls = tieredImagenetOneShot.TieredImagenetOneShotDataset
        base_kwargs = {'image_size': args['ImageSize']}
        # `num_per_base` is not part of the default configuration; forward it
        # only when the user supplied it instead of failing with a KeyError.
        # NOTE(review): assumes get_tiered_base_class_images has a default
        # for num_per_base -- confirm against tieredImagenetOneShot.
        if 'num_per_base' in args:
            base_kwargs['num_per_base'] = args['num_per_base']
        base_images, base_labels = tieredImagenetOneShot.get_tiered_base_class_images(
            args['root'], args['n_base'], args['experts'], **base_kwargs)
        test_episodes = args['batchSize'] * args['epoch']
    else:
        dataset_cls = miniImagenetOneShot.miniImagenetOneShotDataset
        base_images, base_labels = get_base_class_images(
            args['root'], args['n_base'], args['experts'],
            image_size=args['ImageSize'])
        # The miniImagenet test split uses twice as many episodes as training.
        test_episodes = args['batchSize'] * args['epoch'] * 2

    # Print results
    print(f"Generated base class images tensor with shape: {base_images.shape}")
    print(f"Number of classes: {len(base_labels)}")
    print(f"Example class names: {base_labels[:5]}")

    def build_split(split, n_episodes):
        # Every split shares the episode layout; only the split name and the
        # number of episodes differ.
        return dataset_cls(dataroot=args['root'],
                           type=split,
                           nEpisodes=n_episodes,
                           classes_per_set=args['n_way'],
                           samples_per_class=args['n_shot'],
                           samples_per_query=args['n_query'],
                           ImageSize=args['ImageSize'])

    dataTrain = build_split('train', args['batchSize'] * args['epoch'])
    dataVal = build_split('val', args['batchSize'] * args['valruns'])
    dataTest = build_split('test', test_episodes)

    # NOTE(review): num_workers is hard-coded to 4 and args['workers'] is not
    # used here; kept as-is to preserve the existing loading behaviour.
    train_loader = torch.utils.data.DataLoader(dataTrain, batch_size=args['batchSize'],
                                               shuffle=True, num_workers=4)
    val_loader = torch.utils.data.DataLoader(dataVal, batch_size=args['batchSize'],
                                             shuffle=True, num_workers=4)
    test_loader = torch.utils.data.DataLoader(dataTest, batch_size=args['batchSize'],
                                              shuffle=True, num_workers=4)

    print(f"Base class images shape: {base_images.shape}")
    print(f"Number of classes: {len(base_labels)}")
    print(f"Class labels: {base_labels}")
    print(f"Number of images per class: {base_images.shape[1]}")

    # args['model'] was validated above, so 'BDC' is the only case here.
    model = MoE_RIM(args, base_images, base_labels).to(device)

    model = train_model(
            train_loader=train_loader,
            val_loader=val_loader,
            model = model,
            n_way = args['n_way'],
            k_shot = args['n_shot'],
            n_query = args['n_query'],
            epochs = args['Train_epoch'],
            lr = args['lr'],
            kl_weight=1.0,
            args = args
        )
    acc = evaluate(model, test_loader, args['n_way'], args['n_shot'], args['n_query'], args)

    # Append, so results from successive runs accumulate in one file.
    with open(os.path.join(out_dir, 'results.txt'), 'a') as f:
        f.write(f'{args}, acc: {acc}\n')



    
# Run the experiment only when this file is executed as a script.
if __name__ == "__main__":
    main()
